#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
# Relative path: the style file is looked up in the current working
# directory when this module runs, not next to this script.
plt.style.use('./MNRAS_Style.mplstyle')
import os

# x positions (on the plot's MJD time axis) at which plot_w_CD draws one red
# downward arrow each. Empty, so no arrows are drawn.
NEUTRINOS = []

class Gamma:
    """One time bin of a light curve, filled in by the file reader.

    Attributes (all start at 0):
        jd:       time stamp of the bin.
        flux:     measured flux (set for detections).
        flux_err: uncertainty on ``flux``.
        upp:      upper-limit value; 0 means "not an upper limit".
    """

    def __init__(self):
        # Every field defaults to zero; the reader overwrites what applies.
        self.jd = self.flux = self.flux_err = self.upp = 0

def getMJD(MET_s, MET_e):
    """Return the MJD of the midpoint between two MET time stamps.

    MET_s, MET_e: start and end, in seconds counted from MJD 51910.0
    (mission elapsed time). The midpoint is converted to days and offset
    by that reference epoch.
    """
    reference_mjd = 51910.0
    seconds_per_day = 86400.
    half_span = (MET_e - MET_s) / 2.
    midpoint = MET_s + half_span
    return reference_mjd + midpoint / seconds_per_day

def getGammaData(obj, data_dir='../../data/3d/', max_mjd=59976.5,
                 max_flux=1e-5, ts_threshold=4):
    """Read ``<data_dir>/<obj>_daily.csv`` into a list of Gamma objects.

    Each data row is comma-separated. The columns used are
    0 = time stamp, 2 = flux, 3 = flux error, 4 = upper limit, 5 = TS.
    Lines starting with '#' and blank lines are skipped.

    A row is dropped when its flux column is >= ``max_flux`` or its time
    stamp is later than ``max_mjd``. A kept row with TS > ``ts_threshold``
    becomes a detection (``flux`` and ``flux_err`` set). Otherwise it becomes
    an upper limit (``upp`` set from column 4). The defaults keep the
    script's original fixed settings.

    Args:
        obj: source name; selects the file ``<obj>_daily.csv``.
        data_dir: directory holding the light-curve files.
        max_mjd: latest time stamp to keep (inclusive).
        max_flux: rows with a flux column at or above this are discarded.
        ts_threshold: TS strictly above this marks a detection.

    Returns:
        list of Gamma, in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError, IndexError: on a malformed data row.
    """
    gamma = []
    path = os.path.join(data_dir, obj + '_daily.csv')
    # ``with`` closes the file even if a malformed row raises mid-read.
    with open(path) as fop:
        for line in fop:
            # Skip comments, and blank lines, which would otherwise make
            # float('') raise.
            if line.startswith('#') or not line.strip():
                continue
            sl = line.split(',')
            # NOTE(review): presumably filters spurious fits with a huge
            # flux; applied to upper-limit rows too -- confirm intended.
            if float(sl[2]) >= max_flux:
                continue
            obs = Gamma()
            obs.jd = float(sl[0])
            if obs.jd > max_mjd:
                continue
            TS = float(sl[5])
            if TS > ts_threshold:
                obs.flux = float(sl[2])
                obs.flux_err = float(sl[3])
            else:
                obs.upp = float(sl[4])
            gamma.append(obs)
    return gamma

def readCD(obj):
    """Read ``<obj>_CD.dat`` into a list of (MJD, CD, CD_err, tel_band) tuples.

    Each data line is comma-separated: MJD, CD, CD error and the
    telescope/band label. Only the first four fields are used.

    Blank lines and lines starting with '#' are skipped. Surrounding
    whitespace (including the newline) is stripped from the label, so stray
    spaces do not become part of it.

    Returns:
        list of (float, float, float, str) tuples, in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError, IndexError: on a malformed data line.
    """
    CDs = []
    # ``with`` closes the file even if a malformed line raises.
    with open(obj + '_CD.dat') as fop:
        for line in fop:
            # Blank/comment lines would otherwise make float() raise.
            if not line.strip() or line.startswith('#'):
                continue
            sl = line.split(',')
            MJD = float(sl[0])
            CD = float(sl[1])
            CD_err = float(sl[2])
            tel_band = sl[3].strip()
            CDs.append((MJD, CD, CD_err, tel_band))
    return CDs

def _plot_gamma_panel(axis, data, yscale):
    """Draw detections and upper limits from *data* on *axis*; set its scale and limits."""
    # A Gamma with upp == 0 is treated as a detection, anything else as an
    # upper limit.
    fluxes = [o for o in data if o.upp == 0]
    uppers = [o for o in data if o.upp != 0]

    axis.set_ylabel(r'Flux [${\rm10^{-6} ph\,\, sec^{-1}\, cm^{-2}}$]')

    # Values are plotted in units of 1e-6 (see the axis label). One call per
    # series instead of one per point keeps the artist count (and the PDF)
    # small.
    if fluxes:
        jds = [o.jd for o in fluxes]
        vals = [1e6 * o.flux for o in fluxes]
        errs = [1e6 * o.flux_err for o in fluxes]
        # linestyle='none' so the points are not joined by a line.
        axis.errorbar(jds, vals, yerr=errs, linestyle='none', markersize=0.5,
                      elinewidth=0.3, color='k')
        axis.plot(jds, vals, 'ko', markersize=0.5)
    if uppers:
        # uplims with a negligible yerr draws just a downward caret at each
        # limit. The error must be non-negative: recent matplotlib raises
        # ValueError for negative yerr.
        axis.errorbar([u.jd for u in uppers], [1e6 * u.upp for u in uppers],
                      yerr=1e-6, uplims=True, capsize=1.8, markeredgewidth=0,
                      color='#AA3377', linestyle='none')

    # y range: from 0 (1e-3 on a log axis) to 10% above the brightest
    # detection; upper limits do not extend it.
    peak = max((1e6 * o.flux for o in fluxes), default=0)
    top = max(0, peak) * 1.1
    bottom = 1e-3 if yscale == 'log' else 0
    # Set on this Axes explicitly: plt.yscale would act on the *current*
    # Axes, which after plt.subplots is the bottom panel.
    axis.set_yscale(yscale)
    if top > bottom:  # nothing to show -> leave autoscaling, avoid a bad range
        axis.set_ylim([bottom, top])
    if data:
        axis.set_xlim([min(o.jd for o in data), max(o.jd for o in data)])


def _mark_epochs(axis, ymin, ymax):
    """Draw the vertical lines bounding the two highlighted MJD intervals."""
    for mjd in (55700, 56225, 59115, 59542):
        axis.vlines(mjd, ymin, ymax, color='#4477AA')


def _add_year_axis(axis):
    """Add a top x-axis to *axis* with calendar-year ticks at the given MJDs."""
    top = axis.twiny()
    top.set_xticks([54832, 55197, 55562, 55927, 56293, 56658, 57023, 57388,
                    57754, 58119, 58484, 58849, 59215, 59580, 59945])
    top.set_xticklabels(['2009', '2010', '2011', '2012', '2013', '2014',
                         '2015', '2016', '2017', '2018', '2019', '2020',
                         '2021', '2022', '2023'])
    # set_xticks widens the view to include every tick, so copy the limits
    # afterwards to keep the year axis aligned with the MJD axis.
    top.set_xlim(axis.get_xlim())
    # grid(None) would *toggle* the grid (style-dependent); disable it.
    top.grid(False)
    return top


def _plot_cd_panel(axis, CDs):
    """Plot (MJD, CD, CD_err, tel_band) points on a log axis, one colour and legend entry per band."""
    # NOTE(review): 'ZTF i' and 'KAIT WL' share a colour, so they cannot be
    # told apart in the plot -- confirm intended.
    colors = {'ASAS-SN g': "#332288",
              'ASAS-SN V': "#88CCEE",
              'Pan-STARRS g': "#44AA99",
              'Pan-STARRS r': "#117733",
              'ZTF i': "#882255",
              'ZTF g': "#DDCC77",
              'ZTF r': "#CC6677",
              'KAIT WL': "#882255"}
    axis.set_ylim([2, 50])
    axis.set_yscale('log')
    labelled = set()
    for mjd, cd, cd_err, tel_band in CDs:
        color = colors[tel_band]  # unknown band -> KeyError
        axis.errorbar([mjd], [cd], yerr=[cd_err], markersize=1,
                      elinewidth=0.3, color=color)
        # Label only the first point of each band so the legend lists it once.
        extra = {} if tel_band in labelled else {'label': tel_band}
        labelled.add(tel_band)
        axis.plot([mjd], [cd], 'o', markersize=3, color=color, **extra)


def plot_w_CD(obj, CDs, yscale='linear'):
    """Plot the flux light curve of *obj* above its CD points; save ``<obj>_all.pdf``.

    Args:
        obj: source name; the light curve is read with getGammaData(obj) and
            the output is written to ``<obj>_all.pdf`` in the working
            directory.
        CDs: iterable of (MJD, CD, CD_err, tel_band) tuples (as returned by
            readCD). tel_band must be one of the labels in the colour table
            of _plot_cd_panel.
        yscale: y-axis scale of the flux (top) panel, e.g. 'linear' or 'log'.

    Side effects:
        Enables LaTeX text rendering globally via plt.rc, writes the PDF and
        closes the figure.
    """
    data = getGammaData(obj)

    fig, ax = plt.subplots(2, 1, sharex=True)
    fig.set_size_inches(17, 10)
    # Global rc change: affects this and every later figure in the process.
    plt.rc('text', usetex=True)

    _plot_gamma_panel(ax[0], data, yscale)
    _mark_epochs(ax[0], 0, 1)
    ax[0].text(55915, 0.62, '1', fontsize=30, color='#4477AA')
    ax[0].text(59286, 0.62, '2', fontsize=30, color='#4477AA')

    # Arrows go on the bottom (CD) panel: the original plt.arrow drew on the
    # current Axes, which is that panel. y = 3.5 -> 2.0 ends at that panel's
    # lower limit. NOTE(review): confirm this is the intended panel.
    for JDneut in NEUTRINOS:
        ax[1].arrow(JDneut, 3.5, 0, -1.5, width=9, head_length=0.2,
                    facecolor='red', edgecolor='red', zorder=3)

    _add_year_axis(ax[0])

    _plot_cd_panel(ax[1], CDs)
    ymin, ymax = ax[1].get_ylim()
    _mark_epochs(ax[1], ymin, ymax)
    ax[1].set_xlabel('MJD')
    ax[1].set_ylabel('CD')
    ax[1].legend(loc=2)

    fig.savefig(obj + '_all.pdf', bbox_inches='tight')
    # Closing the figure frees it; clf/cla beforehand are unnecessary.
    plt.close(fig)

if __name__ == "__main__":
    import sys

    # Usage: script [OBJECT [YSCALE]]. With no arguments it reproduces the
    # original run (J1748.6+7005, linear flux axis).
    obj = sys.argv[1] if len(sys.argv) > 1 else 'J1748.6+7005'
    yscale = sys.argv[2] if len(sys.argv) > 2 else 'linear'
    CDs = readCD(obj)
    plot_w_CD(obj, CDs, yscale)
